library(tidyverse)
library(tidymodels)
library(workflows)
library(magrittr)
source('data_retrieval.R')
source('preprocess.R')
source('modeling.R')
source('data_retrieval.R')

# Drop any object named `roc_auc` bound in the current environment.
# NOTE(review): presumably this clears a stale value left over from an earlier
# interactive run so the name refers to the metric function again - confirm.
# Guarded with exists() because rm() warns when the object is not there, which
# is the normal case in a fresh session.
if (exists("roc_auc", inherits = FALSE)) {
  rm(roc_auc)
}

#' Retrieve, train and evaluate an XGBoost model for one monthly sample.
#'
#' Keeps the rows of the monthly data whose `S_sample_source_XX` equals
#' `sample_name`, runs them through the preprocessing helpers, splits them,
#' tunes hyperparameters on the downsampled training part, fits the model,
#' scores the test part and saves the results to
#' `revision_monthly_<sample_name>.RData` in the working directory.
#'
#' @param con Passed unchanged to `get_revision_cancer_monthly_data()`.
#' @param sample_name Value of `S_sample_source_XX` to keep; also used to
#'   build the name of the saved `.RData` file.
#' @param hyperparameter_search_iter Forwarded as the fourth argument of
#'   `run_bayes_optimisation()`. Previously this argument was ignored and a
#'   literal 50 was passed instead. NOTE(review): presumably the number of
#'   search iterations - confirm against the definition in modeling.R.
#' @return A list with elements `fit_xgboost`, `preds_test`, `roc_auc` and
#'   `feature_importance` (the same object that is written to disk).
train_monthly_data <- function(con, sample_name, hyperparameter_search_iter = 100) {
  print(paste0("started training ", sample_name, " sample"))

  # Start the clock before the download so that the "Data retrieval took"
  # message below actually includes the retrieval call.
  tic <- Sys.time()
  revision_cancer_monthly <- get_revision_cancer_monthly_data(con)
  # Seed is set after retrieval, exactly where it was before, so the random
  # number stream seen by the steps below is unchanged.
  set.seed(42)
  # preprocess the selected sample and split it
  revision_cancer_monthly_split <- revision_cancer_monthly %>%
    filter(S_sample_source_XX == sample_name) %>%
    add_UTL_features() %>%
    eol_cleaning_pipeline() %>%
    join_phat0(file_name = "revision_p0.RData") %>%
    create_hashed_split(prop = 0.5)
  # format() keeps the unit that difftime picked (secs/mins/hours); pasting the
  # raw difftime would print a bare number.
  print(paste0("Data retrieval took ", format(Sys.time() - tic)))

  # prepare calibration proportion and the downsampled training data
  calibration_proportion <- get_calibration_proportion(
    revision_cancer_monthly_split$train,
    DMG_died_within_365d
  )
  downsampled_train <- downsample_data(
    revision_cancer_monthly_split$train,
    DMG_died_within_365d
  )

  # starting point handed to the hyperparameter search
  initial_model_params <- list(
    tree_depth = 8,
    learn_rate = 0.123001965989824,
    loss_reduction = 11.1582974204794,
    min_n = 2.44744561146945,
    sample_size = 0.773196286475286,
    mtry = 0.549052911438048,
    trees = 400
  )

  # hyperparameter search; keep only the best parameter set, without the
  # `.config` bookkeeping column, so it can be spliced into set_args()
  tic <- Sys.time()
  best_params <- run_bayes_optimisation(
    downsampled_train,
    DMG_died_within_365d,
    initial_model_params,
    hyperparameter_search_iter
  ) %>%
    select_best() %>%
    select(-.config)
  print(paste0(
    "Hyperparamter search took ",
    difftime(Sys.time(), tic, units = "hours"),
    " hours!"
  ))

  # train model
  best_model_spec <- get_xgboost_spec() %>%
    set_args(nthread = 40, !!!best_params)
  train_start <- Sys.time()
  xgboost_fitted <- downsampled_train %>%
    train_model(best_model_spec, outcome = DMG_died_within_365d)
  print(paste0("Training took ", format(Sys.time() - train_start)))

  # get test-set predictions and calibrate them
  preds_test <- get_model_preds(
    model = xgboost_fitted,
    data = revision_cancer_monthly_split$test,
    true_outcome = "DMG_died_within_365d"
  ) %>%
    calibrate_preds(calibration_proportion, .pred_1)

  # get ROC AUC and feature importance; the local result is not called
  # `roc_auc` so it does not shadow the metric function of the same name
  test_roc_auc <- roc_auc(preds_test, truth = true_value, .pred_1)
  feature_importance <- get_feature_importance(xgboost_fitted)

  # extract the fitted model from the workflow object
  # NOTE(review): newer workflows releases deprecate pull_workflow_fit() in
  # favour of extract_fit_parsnip(); switch once the installed version allows.
  fit_xgboost <- pull_workflow_fit(xgboost_fitted)

  results <- list(
    fit_xgboost = fit_xgboost,
    preds_test = preds_test,
    roc_auc = test_roc_auc,
    feature_importance = feature_importance
  )
  save(results, file = paste0("revision_monthly_", sample_name, ".RData"))
  results
}

# Names of the monthly samples to model, in processing order
months <- c(
  "cnr_data", "cnr_data2", "cnr_data3", "cnr_data4",
  "cnr_data5", "cnr_data6", "cnr_data7", "cnr_data8",
  "cnr_data9", "cnr_data10", "cnr_data11", "cnr_data12"
)

# Run retrieval-train-evaluate once per sample and label each result with the
# sample it came from
monthly_result <- setNames(
  map(months, function(sample_name) train_monthly_data(con, sample_name, 100)),
  months
)

# Persist the combined results
save(monthly_result, file = "revision_monthly.RData")




#################################################

#' Build a train/test ROC AUC table for one already-trained monthly sample.
#'
#' Re-runs the same seeded retrieval, preprocessing, split and downsampling
#' steps used for training, loads the results stored in
#' `revision_monthly_<sample_name>.RData`, predicts on the preprocessed
#' downsampled training data with the stored model and computes the ROC AUC
#' of the calibrated predictions. No model is fitted here.
#'
#' @param con Passed unchanged to `get_revision_cancer_monthly_data()`.
#' @param sample_name Value of `S_sample_source_XX` to keep; also used to
#'   locate the saved `.RData` file.
#' @return A data frame with the stored test AUC row (`split = "test"`) and
#'   the newly computed train AUC row (`split = "train"`), plus a `sample`
#'   column holding `sample_name`.
train_monthly_data_train_auc <- function(con, sample_name) {
  print(paste0("started training ", sample_name, " sample"))

  # Start the clock before the download so that the "Data retrieval took"
  # message below actually includes the retrieval call.
  tic <- Sys.time()
  revision_cancer_monthly <- get_revision_cancer_monthly_data(con)
  # Seed is set after retrieval, exactly where it was before, so the random
  # number stream seen by the steps below is unchanged.
  set.seed(42)
  # preprocess the selected sample and split it
  revision_cancer_monthly_split <- revision_cancer_monthly %>%
    filter(S_sample_source_XX == sample_name) %>%
    add_UTL_features() %>%
    eol_cleaning_pipeline() %>%
    join_phat0(file_name = "revision_p0.RData") %>%
    create_hashed_split(prop = 0.5)
  # format() keeps the unit that difftime picked (secs/mins/hours); pasting the
  # raw difftime would print a bare number.
  print(paste0("Data retrieval took ", format(Sys.time() - tic)))

  # prepare calibration proportion and the downsampled training data
  calibration_proportion <- get_calibration_proportion(
    revision_cancer_monthly_split$train,
    DMG_died_within_365d
  )
  downsampled_train <- downsample_data(
    revision_cancer_monthly_split$train,
    DMG_died_within_365d
  )

  # results saved earlier for this sample (fitted model and test AUC are used)
  initial_predict <- get_load_new(
    file_name = paste0("revision_monthly_", sample_name, ".RData")
  )

  # preprocess the training data; the local is not called `recipe` so it does
  # not shadow the recipes function of the same name
  train_recipe <- make_matrix_recipe(downsampled_train, DMG_died_within_365d) %>%
    prep()
  train_data_preprocessed <- juice(train_recipe)

  # class probabilities from the stored model, predictors only
  preds <- predict(
    initial_predict$fit_xgboost,
    train_data_preprocessed %>% select(-DMG_died_within_365d),
    type = "prob"
  )

  preds_train_calibrate <- calibrate_preds(preds, calibration_proportion, .pred_1)

  # NOTE(review): `preds_after_bayes` is presumably a column added by
  # calibrate_preds() - confirm in its definition.
  # NOTE(review): `truth` is given as an external vector rather than a column
  # of the data; newer yardstick releases may require a column name here -
  # verify against the installed version.
  roc_auc_train <- roc_auc(
    preds_train_calibrate,
    truth = train_data_preprocessed$DMG_died_within_365d,
    preds_after_bayes
  )

  # stack stored test AUC and fresh train AUC, tagged by split and sample
  rbind(
    data.frame(initial_predict$roc_auc) %>% mutate(split = "test"),
    data.frame(roc_auc_train) %>% mutate(split = "train")
  ) %>%
    mutate(sample = sample_name)
}


# Names of the monthly samples to evaluate, in processing order
months <- c(
  "cnr_data", "cnr_data2", "cnr_data3", "cnr_data4",
  "cnr_data5", "cnr_data6", "cnr_data7", "cnr_data8",
  "cnr_data9", "cnr_data10", "cnr_data11", "cnr_data12"
)

# Compute the train/test AUC table for every sample
monthly_auc <- map(
  months,
  function(sample_name) train_monthly_data_train_auc(con, sample_name)
)

# Stack the per-sample tables and write them to a single CSV file
write.csv(bind_rows(monthly_auc), "monthly_canc_AUC.csv")








